This session provides an overview of some of the object-oriented programming (OOP) concepts discussed in this module. Use this material to brush up on what you are expected to know in order to complete the last assignment.
The topics covered in this session are listed below:
(This is not a self-contained introduction to OOP; I'll just highlight some of the more important points. For a full revision, use the slides from the previous sessions and the tutorial examples.)
A class is a programming construct that allows us to bundle together a set of data (or variables) along with a set of functions operating on that data.
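The definition of a class consists of a block of statements that starts with the keyword class
followed by the name of the class and a colon; optionally, you can put object in parentheses after the name (object is not a keyword but the built-in base class: in Python 3 every class inherits from it anyway, while in Python 2 this makes the class a "new-style" class):
class RationalNumber(object):
    pass  # the body of this class is empty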
An instance of this class or, in other words, an object of the type 'RationalNumber' is created by a statement like this:
a = RationalNumber()
# We can check the type of this variable 'a' by using
type(a)
#which tells us that the variable 'a' is of the type 'RationalNumber'
We have generated an object of the type 'RationalNumber', but there is no data associated with this object yet. Furthermore, there are no functions inside this class to perform operations on the data.
A rational number is essentially a fraction, which has a numerator and a denominator. We can use these as the attributes (i.e., data) of our class:
from __future__ import print_function
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator
    def pprint(self):
        print("DATA STORED:")
        print("\tNumerator: {}".format(self.numerator))
        print("\tDenominator: {}".format(self.denominator))
        print("---------------------------")
The first function introduced above, __init__, is the constructor of our class; it allows us to initialise objects that have specific values for their attributes. The name of this function is recognised by the Python interpreter: __init__ is called automatically whenever a new object is created. The second function, pprint, is simply used to display the data contained in a particular object. To test this class we can use the following lines of code:
# Create two instances:
# Note that we pass specific values to create our objects
a1 = RationalNumber(3, 4)
a2 = RationalNumber(7, 11)
# Display on the screen the data stored
# in the objects a1 and a2:
a1.pprint()
a2.pprint()
New objects of the type 'RationalNumber' are created by using the class name as if it were a function. The statement 'a1 = RationalNumber(3, 4)' does two things: (i) it first creates an empty object; (ii) it then applies the function __init__ to that object, i.e. the equivalent of a1.__init__(3, 4) is executed.
The first parameter of __init__ refers to the new object itself. When a method is called, Python automatically passes the instance on which it was called as this first argument. This rule applies to all methods of the class, not just to the special method __init__. By convention, this first parameter is named 'self'.
In the above example the function __init__ defines two attributes of the new object, 'numerator' and 'denominator'. The two functions inside the class, '__init__' and 'pprint', are known as the methods of the instance (the first argument of such functions is always the variable 'self').
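To see the 'self' mechanism at work, the short sketch below calls the same method in two ways: through an instance, and through the class with the instance passed explicitly as the first argument. The method name 'as_tuple' is hypothetical and is not part of the class above; both calls run the same code with self bound to a1.

```python
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator

    def as_tuple(self):  # hypothetical helper, added only for this illustration
        return (self.numerator, self.denominator)

a1 = RationalNumber(3, 4)
via_instance = a1.as_tuple()             # Python passes a1 as 'self' automatically
via_class = RationalNumber.as_tuple(a1)  # the same call, with 'self' passed explicitly
print(via_instance, via_class)
```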
In the above example we defined the method 'pprint' to display on the screen the values of the attributes stored in objects generated with our 'RationalNumber' class. While this is perfectly acceptable, it would be more natural if we could just type something like 'print(a1)' to produce the same effect (i.e., use the standard 'print' function). It turns out that in Python one can define special methods that are called automatically when standard Python syntax is used. This allows objects to be used more naturally than by calling methods by name.
Such special methods start and end with a double underscore (you should never do this to name your own methods).
The __str__ special method:
from __future__ import print_function
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator
    def __str__(self):
        s = "The fraction is: "
        s += "{}/{}".format(self.numerator, self.denominator)
        return s
# testing:
a3 = RationalNumber(2, 17)
print(a3)
Note that there are no print statements inside the __str__ method; we simply return a string containing the information that we want displayed on the screen. When we call the standard 'print' function on an object from our class, Python automatically calls the __str__ method and displays the string it returns.
The __repr__ special method:
from __future__ import print_function
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator
    def __repr__(self):
        return "{}/{}".format(self.numerator, self.denominator)
# testing:
a4 = RationalNumber(21, 13)
a4  # we can get the fraction by simply typing its name,
    # as is typically the case with most variables in Python
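When a class defines both methods, print() and str() use __str__, while repr(), the interactive prompt and the display of containers such as lists use __repr__; if only __repr__ is defined (as in the class above), print() falls back to it. The sketch below combines the two versions above. The constructor-style string returned by __repr__ is a common convention and an assumption of this sketch, not what the class above returns.

```python
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator

    def __str__(self):
        # used by print() and str()
        return "The fraction is: {}/{}".format(self.numerator, self.denominator)

    def __repr__(self):
        # used by repr(), the interactive prompt, and inside containers
        return "RationalNumber({}, {})".format(self.numerator, self.denominator)

r = RationalNumber(21, 13)
print(r)        # The fraction is: 21/13
print(repr(r))  # RationalNumber(21, 13)
print([r])      # [RationalNumber(21, 13)]
```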
Next, let's say we want our class 'RationalNumber' to contain a method that would allow us to add two rational numbers. Recall that
$$
\frac{p_1}{q_1} + \frac{p_2}{q_2} = \frac{p_1q_2 + p_2q_1}{q_1q_2}\,,
$$
which suggests that our method will have to return another object of the type 'RationalNumber' in which numerator $= p_1q_2 + p_2q_1$ and denominator $=q_1q_2$.
We would also like to be able to add an integer to a fraction. To reuse the same formula, we treat the integer $n$ as the fraction $n/1$, i.e. we set $p_2 = n$ and $q_2 = 1$.
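For example, with the values used in the test code that follows ($2/3$, $1/6$ and the integer $2$), the formula gives
$$
\frac{2}{3} + \frac{1}{6} = \frac{2\cdot 6 + 1\cdot 3}{3\cdot 6} = \frac{15}{18}\,,\qquad
\frac{2}{3} + 2 = \frac{2}{3} + \frac{2}{1} = \frac{2\cdot 1 + 2\cdot 3}{3\cdot 1} = \frac{8}{3}\,.
$$
Note that the formula does not reduce the result to lowest terms ($15/18 = 5/6$).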
from __future__ import print_function
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator
    def __repr__(self):
        return "{}/{}".format(self.numerator, self.denominator)
    def add(self, other):
        p1, q1 = self.numerator, self.denominator
        # Check first to see if the other number is an integer:
        if isinstance(other, int):
            p2, q2 = other, 1
        # If the other number is also a proper RationalNumber:
        else:
            p2, q2 = other.numerator, other.denominator
        return RationalNumber(p1*q2 + p2*q1, q1*q2)
# testing
b1 = RationalNumber(2, 3)
b2 = RationalNumber(1, 6)
b3 = 2
res1 = b1.add(b2)  # b1+b2
res2 = b1.add(b3)  # b1+b3
print(res1); print(res2)
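One way to check these results is to compare them with the standard library's fractions.Fraction type, which represents the same numbers exactly (and, unlike our class, always reduces them to lowest terms). The sketch below redoes the two sums with Fraction; it is a cross-check only, not a replacement for the class.

```python
from fractions import Fraction

# the same sums as b1.add(b2) and b1.add(b3) above
s1 = Fraction(2, 3) + Fraction(1, 6)
s2 = Fraction(2, 3) + 2
print(s1, s2)  # 5/6 8/3

# our class returns 15/18 for the first sum: the same number, just not reduced
print(Fraction(15, 18) == s1)  # True
```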
So it seems that our code does the job, but it would be much nicer if we could just write $c_1+c_2$, where $c_1$ and $c_2$ are objects of the class 'RationalNumber'. At the moment we can't do this because the plus sign is not defined for objects of our class. Python provides a special method which, once it is included in our class (and properly implemented), removes this limitation; this special method is called __add__. We have already done all the hard work; the only thing left to do is to change the name of the method.
The __add__ special method:
from __future__ import print_function
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator
    def __repr__(self):
        return "{}/{}".format(self.numerator, self.denominator)
    def __add__(self, other):
        p1, q1 = self.numerator, self.denominator
        # Check first to see if the other number is an integer:
        if isinstance(other, int):
            p2, q2 = other, 1
        # If the other number is also a proper RationalNumber:
        else:
            p2, q2 = other.numerator, other.denominator
        return RationalNumber(p1*q2 + p2*q1, q1*q2)
# testing:
b1 = RationalNumber(2, 3)
b2 = RationalNumber(1, 6)
b3 = 2
res1 = b1 + b2; res2 = b1 + b3
print(res1); print(res2)
In Session 11 a list of several special methods was provided and some of them were illustrated on the example of the class 'Vector2D'. There are many more special methods (see the official documentation).
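One of these is worth mentioning here. With only __add__ defined, b1 + 2 works but 2 + b1 raises a TypeError: Python first tries int.__add__, which does not know about our class, and then looks for a "reflected" method __radd__ on the right-hand operand. A minimal sketch of the fix, assuming we are happy to rely on addition being commutative:

```python
class RationalNumber(object):
    def __init__(self, numerator, denominator):
        self.numerator = numerator
        self.denominator = denominator

    def __repr__(self):
        return "{}/{}".format(self.numerator, self.denominator)

    def __add__(self, other):
        p1, q1 = self.numerator, self.denominator
        if isinstance(other, int):
            p2, q2 = other, 1
        else:
            p2, q2 = other.numerator, other.denominator
        return RationalNumber(p1*q2 + p2*q1, q1*q2)

    def __radd__(self, other):
        # called for 'other + self' when other's own __add__ cannot handle self;
        # addition is commutative, so we simply reuse __add__
        return self.__add__(other)

b1 = RationalNumber(2, 3)
print(b1 + 2)  # 8/3 (uses __add__)
print(2 + b1)  # 8/3 (int.__add__ gives up, so Python calls b1.__radd__(2))
```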
In the previous session we obtained the solution of a simple IVP in the form
$$
y(x) = -x - 1 + C e^x\,,
$$
where $C\in\mathbb{R}$ was a constant that was fixed by the initial condition. If $y(x_0) = y_0$ then
$$
-x_0 - 1 + Ce^{x_0} = y_0\quad\Longrightarrow\quad
C = (x_0 + y_0 + 1)e^{-x_0}\,,
$$
and hence our solution can be cast in the form
$$
y(x) = -x - 1 + (x_0 + y_0 + 1)e^{x-x_0}\,.
$$
While there is no difficulty in representing this expression by a Python function, it is sometimes more convenient to write it as a class (e.g., we can add a method to our class to plot the function -- a task that would be more laborious if we just implemented the expression as a function).
import numpy as np
class mySoln(object):
    def __init__(self, x0=0.0, y0=1.0):
        self.x0 = x0
        self.y0 = y0
    def __call__(self, x):
        x0, y0 = self.x0, self.y0  # use aliases to simplify the notation
        return -x - 1 + (x0 + y0 + 1.0)*np.exp(x - x0)
# testing:
y = mySoln()  # the default values for x0 and y0 will be used
# this 'y' is an object
x = 2.3; print(y(x))  # y(x) has the usual meaning -- thanks to the __call__ method
y = mySoln(0.0, 6.0)  # now we use x0 = 0.0 and y0 = 6.0 --> y is a different object
x = 2.3; print(y(x))  # evaluate the new object at the same 'x'
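Differentiating the solution gives $y'(x) = -1 + (x_0 + y_0 + 1)e^{x-x_0}$, which equals $x + y(x)$, so the object should satisfy the ODE $y' = x + y$ together with its initial condition (the same right-hand side, f(x, y) = x + y, is used in the ODE solver at the end of this session). The sketch below checks both numerically; the central-difference step h = 1e-5 and the sample points are arbitrary choices.

```python
import numpy as np

class mySoln(object):
    def __init__(self, x0=0.0, y0=1.0):
        self.x0 = x0
        self.y0 = y0
    def __call__(self, x):
        x0, y0 = self.x0, self.y0
        return -x - 1 + (x0 + y0 + 1.0)*np.exp(x - x0)

y = mySoln(0.0, 6.0)
print(y(0.0))  # the initial condition y(0) = 6 is recovered

# central-difference approximation of y'(x) at a few points
h = 1.0e-5
xs = np.linspace(0.0, 2.0, 5)
dydx = (y(xs + h) - y(xs - h))/(2.0*h)
residual = dydx - (xs + y(xs))  # should be close to zero
print(np.max(np.abs(residual)))
```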
The special method __call__ makes it possible to call our object as if it were an ordinary function. Instances of a class with such a method are said to be callable objects. In Python every function is callable, but an instance of a class is callable only if its class defines __call__. We can test whether a given object is callable by using the built-in function callable:
if callable(y):
    print("Our object is callable")
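To contrast the two situations, the sketch below uses two hypothetical classes: Scale defines __call__, Point does not. Calling an instance of Point raises a TypeError.

```python
class Scale(object):  # hypothetical class that defines __call__
    def __init__(self, factor):
        self.factor = factor
    def __call__(self, x):
        return self.factor*x

class Point(object):  # hypothetical class without __call__
    def __init__(self, x):
        self.x = x

double = Scale(2.0)
p = Point(1.0)
print(callable(double), callable(p))  # True False
print(double(3.0))                    # 6.0, the same as double.__call__(3.0)
try:
    p(3.0)
except TypeError as err:
    print("TypeError:", err)
```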
Below, there are a couple of longer examples that illustrate the use of classes in mathematical programming:
EXAMPLE I:
Consider the motion of a projectile. The question is to find the position of the projectile at a given moment of time after its launch, and plot its trajectory up to that point.
The problem can be solved in a number of different ways, but an OOP approach provides a particularly neat solution. Recall that the parametric equations for the trajectory of a projectile are given by the formulae
$$
x = x_0+(v_0\cos\alpha)t\,,\qquad
y = y_0+(v_0\sin\alpha)t - \frac{1}{2}gt^2\,,
$$
where $v_0$ is the initial speed (or velocity), $\alpha$ is the angle of projection, $g\simeq 9.8$
is the gravitational acceleration, $t$ represents time, and $(x_0, y_0)$ represent the Cartesian coordinates of the location from where the projectile is launched. For simplicity, we shall assume that $x_0=0$, but will allow for $y_0\neq 0$.
We construct a class Projectile whose two main methods take care of the aforementioned tasks. The first method, 'update', calculates the new position of the projectile based on the argument 'time' (provided by the user). As this parameter is arbitrary, there are two cases to consider: (i) if 'time' < time of flight of the projectile, we simply use the above formulae to find the current values of $x$ and $y$; (ii) if 'time' >= time of flight, the projectile is already on the ground (i.e., $y=0$) and we only need to figure out $x$ (essentially, the range of the projectile). To distinguish between these two cases we need a conditional ('if') statement. The second method, 'draw', plots the trajectory of the projectile; what the plot looks like depends on which of the two cases applies. We use an attribute called 'tc' (the time up to which the trajectory is drawn), which is updated by the 'update' method, so we must make sure that 'update' is called before 'draw'. You may want to study the code and experiment with it to understand what it does.
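The landing time can also be found in closed form. Setting $y = 0$ in the second formula above and taking the positive root of the quadratic gives $t_f = \big(v_0\sin\alpha + \sqrt{(v_0\sin\alpha)^2 + 2gy_0}\big)/g$ (valid for $y_0 \ge 0$). The sketch below uses it to compute the time of flight and the range; the function names time_of_flight and horizontal_range are hypothetical. It can serve as a check on the numerical search that the 'update' method performs in case (ii).

```python
from math import sin, cos, radians, sqrt

def time_of_flight(angle, velocity, height, g=9.8):
    """Positive root of height + vy*t - g*t**2/2 = 0 (hypothetical helper)."""
    vy = velocity*sin(radians(angle))
    return (vy + sqrt(vy**2 + 2.0*g*height))/g

def horizontal_range(angle, velocity, height, g=9.8):
    """Horizontal distance covered up to the landing time (hypothetical helper)."""
    return velocity*cos(radians(angle))*time_of_flight(angle, velocity, height, g)

# the inputs used in getInputs() below: angle 66 degrees, speed 2.3, height 1.0
tf = time_of_flight(66, 2.3, 1.0)
print("time of flight: {:0.4f} s".format(tf))
print("range: {:0.4f} m".format(horizontal_range(66, 2.3, 1.0)))
# launched from the ground at 45 degrees, the range reduces to v0**2/g
print(horizontal_range(45, 10.0, 0.0), 10.0**2/9.8)
```

With these inputs the time of flight is about 0.71 s, so the test time of 0.5 s used below falls into case (i).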
from __future__ import print_function
%matplotlib notebook
from math import sin, cos, radians
from matplotlib import pyplot as plt
import numpy as np
# An example of class with several methods:
class Projectile(object):
    def __init__(self, angle, velocity, height, g=9.8):
        self.xpos = 0.0
        self.ypos = height
        theta = radians(angle)  # the angle is given in degrees
        self.xvel = velocity*cos(theta)
        self.yvel = velocity*sin(theta)
        self.g = g  # gravitational acceleration
        self.tc = 0.0  # current time
        self.elevation = height  # initial elevation
    def update(self, time):
        # Updates the current position (stored in the attributes)
        t, g = time, self.g
        xpos_new = (self.xvel)*t
        ypos_new = self.elevation + (self.yvel)*t - 0.5*g*(t**2)
        # if the particle is still in flight:
        if ypos_new > 0:
            self.xpos = xpos_new
            self.ypos = ypos_new
            self.tc = time
        # if the particle has already returned to the ground:
        else:
            tmp = np.linspace(0.0, time, 5000)
            ytmp = self.elevation + (self.yvel)*tmp - 0.5*g*(tmp*tmp)
            # find the first sample below the ground; the bound on i
            # stops the search at the last sample
            i = 0
            while i < len(tmp) - 1 and ytmp[i] >= 0:
                i += 1
            self.xpos = self.xvel*tmp[i]
            self.ypos = 0
            self.tc = tmp[i]
        # note that this method does NOT return anything
        # it simply updates the original attributes
    def getY(self):
        return self.ypos
    def getX(self):
        return self.xpos
    def draw(self, time, N, color='b'):
        g = self.g
        # draw up to the landing time if the projectile has already landed
        if self.tc < time:
            tmp = np.linspace(0.0, self.tc, N)
        else:
            tmp = np.linspace(0.0, time, N)
        xvals = (self.xvel)*tmp
        yvals = self.elevation + (self.yvel)*tmp - 0.5*g*(tmp*tmp)
        plt.plot(xvals, yvals, color, xvals[-1], yvals[-1], 'or')
        plt.xlabel('x')
        plt.ylabel('y')
        plt.grid(True)
        plt.show()
#EndOfClass definition
# The two functions included below are for testing the above class
def getInputs():
    '''
    Can use a separate function for defining various parameters
    '''
    a = 66  # angle (degrees)
    v = 2.3  # initial speed
    h = 1.0  # initial height
    t = 0.5  # time
    return a, v, h, t
def main():
    # Initialize the variables on the left:
    angle, vel, h0, time = getInputs()
    # Create an instance of the Projectile() class:
    projectile = Projectile(angle, vel, h0)
    # Call the update() method to find the current values
    # of the attributes defined in the constructor:
    projectile.update(time)
    # Print on the screen the horizontal distance traveled:
    print("\nHorizontal distance traveled: {0:0.4f} meters.".format(projectile.getX()))
    projectile.draw(time, 1000)
main()
EXAMPLE II:
Let's consider a class that defines an object for planar triangles. Such triangles can be defined by specifying three arbitrary points. We want our class to have a method for calculating the area of such triangles, and another method for sketching them.
From CFM2104 we know that, if $A$, $B$, and $C$ are the vertices of a triangle then its area is given by
$$
\frac{1}{2}|\overrightarrow{AB} \wedge \overrightarrow{AC}|\,.
$$
That is, we take the vector product of $\overrightarrow{AB}$ and $\overrightarrow{AC}$, calculate its magnitude, and divide the result by $2$.
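As a worked example, take the second test triangle used below, with $A=(0,0)$, $B=(1,1.1)$ and $C=(2,3)$. Treating the plane vectors as 3D vectors with zero third component,
$$
\overrightarrow{AB} = (1,\,1.1,\,0)\,,\quad \overrightarrow{AC} = (2,\,3,\,0)\,,\quad
\overrightarrow{AB}\wedge\overrightarrow{AC} = (0,\,0,\,1\cdot 3 - 1.1\cdot 2) = (0,\,0,\,0.8)\,,
$$
so the area is $\tfrac{1}{2}\cdot 0.8 = 0.4$. For plane vectors only the third component of the vector product can be nonzero, which is why np.cross returns a single number for 2D inputs.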
from __future__ import print_function
import numpy as np
from matplotlib import pyplot as plt
#
# we use the cross-product function from numpy: for two 2D vectors,
# np.cross returns the (scalar) third component of their vector product,
# and np.abs then gives its magnitude
class Triangle(object):
    """Class for triangles defined by 3 points"""
    def __init__(self, A, B, C):
        # vertices:
        self.A = np.array(A)
        self.B = np.array(B)
        self.C = np.array(C)
        # the sides (as vectors):
        self.a = self.C - self.B  # vector BC
        self.b = self.C - self.A  # vector AC
        self.c = self.B - self.A  # vector AB
    def area(self):
        return 0.5*np.abs(np.cross(self.b, self.c))
    def draw(self):
        pass
# testing:
# first choose a triangle whose area you can calculate by inspection
# e.g., a right-angled triangle will do:
triang1 = Triangle([0.0, 0.0], [2.0, 0.0], [0.0, 3.0])  # its area is (2 x 3)/2 = 3
# now choose another triangle:
triang2 = Triangle([0.0, 0.0], [1.0, 1.1], [2.0, 3.0])
print("Area triangle 1 is: ", triang1.area())
print("Area triangle 2 is: {:1.3f}".format(triang2.area()))
There is a problem with this code, even though the result we have above is correct.
Let's change B and see if the area changes.
triang1.B = [10.0, 0.0] # the new area would have to change to 15
triang1.area()
Clearly, the result has not changed. This happens because the attributes 'a' and 'c' (both of which depend on B) are computed once, in the constructor, and are not automatically updated when B changes. This example illustrates the dangers of modifying attributes from outside a class. In general we can guard against this by making certain attributes either protected or private.
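Another way to avoid stale data, sketched below as an alternative to the approach taken in this session, is not to store derived quantities at all: if area() computes the sides from the current vertices every time it is called, changing a vertex from outside the class can no longer leave the object in an inconsistent state. The trade-off is a little extra work on each call. This sketch computes the magnitude of the 2D cross product explicitly, as $|b_x c_y - b_y c_x|$, instead of calling np.cross.

```python
import numpy as np

class Triangle(object):
    """Triangle that stores only its vertices (sketch)."""
    def __init__(self, A, B, C):
        self.A, self.B, self.C = A, B, C

    def area(self):
        # the sides are recomputed from the current vertices on every call
        A, B, C = (np.asarray(P, dtype=float) for P in (self.A, self.B, self.C))
        b = C - A  # vector AC
        c = B - A  # vector AB
        return 0.5*abs(b[0]*c[1] - b[1]*c[0])  # magnitude of the 2D cross product

t = Triangle([0.0, 0.0], [2.0, 0.0], [0.0, 3.0])
before = t.area()
t.B = [10.0, 0.0]     # modify a vertex from outside the class
after = t.area()
print(before, after)  # 3.0 15.0
```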
For protected attributes the Python convention is to prefix the attribute name with an underscore. However, doing that does not change the class users' ability to change the attribute from outside the class. The single underscore is simply a marker letting them know not to access or change the attribute from outside the class.
For private attributes, we prefix the attribute name with a double underscore. Python then "mangles" the name (self.__B inside the class Triangle is stored as _Triangle__B), so the attribute cannot be accessed from outside the class under its original name; it can still be reached through the mangled name, but doing so defeats the purpose. If we follow this route we need to provide getter and setter methods (if we still want to access that attribute from outside the class). We discussed accessor and mutator methods in Session 12, and have used the concept in several other sessions afterwards.
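The difference between the two conventions can be seen in the short sketch below; the class Demo and its attribute names are hypothetical.

```python
class Demo(object):              # hypothetical class for illustration
    def __init__(self):
        self._protected = 1      # single underscore: a convention only
        self.__private = 2       # double underscore: stored as _Demo__private

    def get_private(self):
        return self.__private    # inside the class the short name works

d = Demo()
print(d._protected)              # 1: nothing stops us from reading it
print(d.get_private())           # 2: access through a getter
try:
    d.__private                  # the original name is not visible outside
except AttributeError as err:
    print("AttributeError:", err)
print(d._Demo__private)          # 2: the mangled name is still reachable
```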
Let's revisit the above 'Triangle' class by writing setter and getter methods for the vertex B
(in principle, we would have to provide a pair of such methods for each vertex).
from __future__ import print_function
import numpy as np
from matplotlib import pyplot as plt
class Triangle(object):
    """Class for triangles defined by 3 points"""
    def __init__(self, A, B, C):
        # vertices:
        self._A = np.array(A)
        self._B = np.array(B)
        self._C = np.array(C)
        # the sides (as vectors):
        self._a = self._C - self._B  # vector BC
        self._b = self._C - self._A  # vector AC
        self._c = self._B - self._A  # vector AB
    def area(self):
        return 0.5*np.abs(np.cross(self._b, self._c))
    def set_B(self, B):
        # store the new vertex and recompute the sides that depend on it
        self._B = np.array(B)
        self._a = self._C - self._B
        self._c = self._B - self._A
    def get_B(self):
        return self._B
    def draw(self):
        pass
    B = property(fget=get_B, fset=set_B)  # see below for explanations
#testing
triang1 = Triangle([0.0, 0.0], [2.0, 0.0], [0.0, 3.0])
triang1.B = [10.0, 0.0]  # the new area would have to change to 15
triang1.area()
We have used the built-in function property to link an attribute to its getter and setter methods. We could avoid using this function, but then we would have to call the set_B and get_B methods explicitly from outside the class. The syntax for the above function is
Name of the attribute = property(fget=Name of your getter, fset=Name of your setter, fdel=Name of your deleter, doc=...)
where fget, fset, fdel and doc are keyword arguments, and everything else depends on what you name your attribute, getter, setter, etc. All four arguments are optional: fdel is the method called when the attribute is deleted (del triang1.B), and doc is a docstring for the attribute.
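property can also be used as a decorator, which is the form you will most often meet in other code. The sketch below is an equivalent version of the B property above; only the parts needed for the area are kept, and the area is computed from the explicit 2D cross product rather than np.cross.

```python
import numpy as np

class Triangle(object):
    """Triangle with B exposed through the @property decorator (sketch)."""
    def __init__(self, A, B, C):
        self._A = np.array(A, dtype=float)
        self._C = np.array(C, dtype=float)
        self._b = self._C - self._A  # vector AC (does not depend on B)
        self.B = B                   # runs the setter below, which sets _B, _a and _c

    @property
    def B(self):                     # getter: the role played by get_B
        return self._B

    @B.setter
    def B(self, value):              # setter: the role played by set_B
        self._B = np.array(value, dtype=float)
        self._a = self._C - self._B  # vector BC
        self._c = self._B - self._A  # vector AB

    def area(self):
        # magnitude of the 2D cross product of AC and AB
        return 0.5*abs(self._b[0]*self._c[1] - self._b[1]*self._c[0])

t = Triangle([0.0, 0.0], [2.0, 0.0], [0.0, 3.0])
before = t.area()
t.B = [10.0, 0.0]     # calls the setter, so the sides are recomputed
after = t.area()
print(before, after)  # 3.0 15.0
```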
The code without the 'property' function is included below:
from __future__ import print_function
import numpy as np
from matplotlib import pyplot as plt
class Triangle(object):
    """Class for triangles defined by 3 points"""
    def __init__(self, A, B, C):
        # vertices:
        self._A = np.array(A)
        self._B = np.array(B)
        self._C = np.array(C)
        # the sides (as vectors):
        self._a = self._C - self._B  # vector BC
        self._b = self._C - self._A  # vector AC
        self._c = self._B - self._A  # vector AB
    def area(self):
        return 0.5*np.abs(np.cross(self._b, self._c))
    def set_B(self, B):
        self._B = np.array(B)
        self._a = self._C - self._B
        self._c = self._B - self._A
    def get_B(self):
        return self._B
    def draw(self):
        # you should implement this yourselves
        # (to check your understanding....)
        pass
#testing
triang3 = Triangle([0.0, 0.0], [2.0, 0.0], [0.0, 3.0])
B = [10.0, 0.0]  # the new area would have to change to 15
triang3.set_B(B)  # without using the 'property' function
print(triang3.area())
print(triang3.get_B())
Attributes specified in the class definition (outside the constructor and the other methods) are called static attributes (in some texts they are referred to as class attributes). They are shared among all instances generated by the corresponding class, and they are usually accessed by prefixing them with the class name rather than 'self' (reading them through 'self' also works, but assigning to self.<name> creates a separate instance attribute instead of changing the shared value). This topic was discussed in Session 12.
Class attributes are particularly useful to simulate default values, and can be used if values
have to be reset. In the example below TOL, MAX_ITER, h are such variables (which admit default values).
from __future__ import print_function
class Newton(object):
    """
    A simple class for the Newton-Raphson Method
    """
    TOL = 1.0E-7
    MAX_ITER = 100
    h = 1.0E-5
    def __init__(self, f):
        self.f = f
    # the derivative is approximated using finite differences:
    def dfdx(self, x):
        f = self.f
        h = Newton.h
        return (f(x + h) - f(x))/float(h)
    # this is the actual solver:
    def solve(self, x):
        f = self.f
        n = 0
        while abs(f(x)) > Newton.TOL and n < Newton.MAX_ITER:
            x = x - f(x)/self.dfdx(x)
            n += 1
        return x, n, f(x)
# testing:
def f(x): return (x-1.0)*(x-5.0)  # we are trying to solve f(x) = 0
N1 = Newton(f)  # create an object N1 of type Newton
xguess = 0.3; res = N1.solve(xguess)  # the method solve() requires a starting value (guess)
#xguess = 3.2; res = N1.solve(xguess)  # if you want the other root
print("The root is: {:2.4f}".format(res[0]))
print("Number of Iterations: {}".format(res[1]))
If we want to change some of the static attributes, that's very easy.
N1 = Newton(f)  # at this point TOL has its default/original value
Newton.TOL = 1.0E-1  # change TOL to 0.1 (bad idea...); this affects ALL instances, N1 included
N2 = Newton(f)  # instantiate a Newton object (it uses the new TOL)
res = N2.solve(0.3)  # use the solve method on the new object
# Display the results on the screen and compare them with those
# obtained earlier with N1 (before TOL was changed):
print("The root is: {:2.4f}".format(res[0]))
print("Number of Iterations: {}".format(res[1]))
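Because the class attribute is shared, the change above affects every Newton object. Assigning through an instance behaves differently: it creates an instance attribute that shadows the class attribute for that object only. The sketch below uses a hypothetical class Settings. Note that Newton.solve reads Newton.TOL, so such an instance attribute would be ignored by the solver; reading self.TOL instead would let individual objects override the default.

```python
class Settings(object):  # hypothetical class for illustration
    TOL = 1.0e-7         # class (static) attribute, shared by all instances

s1 = Settings()
s2 = Settings()
Settings.TOL = 1.0e-3    # change through the class: every instance sees it
print(s1.TOL, s2.TOL)    # 0.001 0.001

s1.TOL = 0.5             # assignment through an instance creates an
                         # instance attribute on s1 only
print(s1.TOL, s2.TOL, Settings.TOL)          # 0.5 0.001 0.001
print("TOL" in vars(s1), "TOL" in vars(s2))  # True False
```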
In this section we review some key concepts in OOP: abstract classes, subclasses and inheritance. We spent considerable time on these concepts in Sessions 13 and 14.
Inheritance is a relationship between a more general class (called the superclass or the parent class) and a more specialised class (called the subclass or child class). The subclass inherits data and behaviour from the superclass.
For example, every car is a vehicle. Cars and vehicles share a number of common features (they both have an engine, they both have wheels, etc). Obviously the term 'vehicle' is broader than the term 'car': a motorbike is a vehicle, but is not a car, etc. We can say that the class 'Car' inherits from the 'Vehicle' class; in this relationship the latter is the superclass and the former is the subclass.
A subclass constructor can only define the instance variables of the subclass. More often than not, the superclass instance variables must be defined as well, so the constructor of the subclass must explicitly call the superclass constructor. Because the constructors of both classes have the same name, we must be careful about the syntax. Recall the following example from Session 13:
from __future__ import print_function
class Counter(object):
    def __init__(self):  # superclass constructor
        self.value = 0
    def increment(self):
        self.value += 1
        return self.value
class CustomCounter(Counter):
    def __init__(self, size):  # subclass constructor
        Counter.__init__(self)  # call superclass constructor
        self.stepsize = size
    def increment(self):
        self.value += self.stepsize  # overriding the parent behaviour
        return self.value
#
#-- Testing ------------------------------------------------------------------------
C1 = Counter()
C2 = CustomCounter(3)
for j in range(5):
    print(C1.increment(), C2.increment())
The subclass inherits the methods in the superclass. We can override this by specifying a new implementation in the subclass. In addition, new methods can be included in the subclass.
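The explicit call Counter.__init__(self) can also be written with the built-in super, which finds the parent class for us (in Python 3 the shorter super().__init__() does the same). The sketch below uses the two-argument form, which also works in Python 2, and checks the inheritance relationship with isinstance and issubclass.

```python
class Counter(object):
    def __init__(self):
        self.value = 0
    def increment(self):
        self.value += 1
        return self.value

class CustomCounter(Counter):
    def __init__(self, size):
        super(CustomCounter, self).__init__()  # same effect as Counter.__init__(self)
        self.stepsize = size
    def increment(self):
        self.value += self.stepsize
        return self.value

c = CustomCounter(3)
values = [c.increment() for _ in range(3)]
print(values)                                                # [3, 6, 9]
print(isinstance(c, CustomCounter), isinstance(c, Counter))  # True True
print(issubclass(CustomCounter, Counter))                    # True
```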
We illustrate the idea of abstract class and inheritance by using two of the integration methods reviewed in the previous section: the rectangle method and the midpoint method. The code for the former has already been mentioned/tested in Session 19 (it is included below for your convenience):
def integrateREC(fname, xmin, xmax, intervals):
    """ Integrate by using the rectangle rule."""
    h = (xmax - xmin) / float(intervals)
    total = 0.0
    # Perform the integration:
    x = xmin
    for interval in range(intervals):
        # Add the area of the rectangle for this slice:
        total += h * fname(x)
        # Move to the next slice:
        x += h
    return total  # this returns the approx. of the integral
The midpoint rule was left as an exercise -- the code can be found below:
def integrateMID(fname, xmin, xmax, intervals):
    """ Integrate by using the mid-point rule."""
    h = (xmax - xmin) / float(intervals)
    total = 0.0
    # This is the only thing we have to change:
    x = xmin + 0.5*h  # midpoint for the interval (x0, x1)
    for interval in range(intervals):
        # Add the area of the rectangle for this slice:
        total += h * fname(x)
        # Move to the next slice:
        x += h  # the next midpoint, etc
    return total  # this returns the approx. of the integral
# Testing:
import numpy as np
def f(x): return np.sin(x)
print(integrateMID(f, 0.0, np.pi, 50))
print(integrateREC(f, 0.0, np.pi, 50))  # check the rectangle method as well
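Since $\int_0^\pi \sin x\,dx = 2$, both results can be compared with the exact value. The vectorised sketch below evaluates the same sums without an explicit loop (the function names rectangle and midpoint are hypothetical, and f is assumed to accept NumPy arrays) and records the errors for 50 and 100 intervals. For this particular integrand, which vanishes at both end points so that the rectangle sum coincides with the trapezium sum, doubling the number of intervals should reduce both errors by roughly a factor of four.

```python
import numpy as np

def rectangle(f, a, b, n):
    h = (b - a)/float(n)
    x = a + h*np.arange(n)          # left end points of the slices
    return h*np.sum(f(x))

def midpoint(f, a, b, n):
    h = (b - a)/float(n)
    x = a + h*(np.arange(n) + 0.5)  # midpoints of the slices
    return h*np.sum(f(x))

exact = 2.0  # integral of sin(x) from 0 to pi
errors = {}
for n in (50, 100):
    errors[n] = (abs(rectangle(np.sin, 0.0, np.pi, n) - exact),
                 abs(midpoint(np.sin, 0.0, np.pi, n) - exact))
    print(n, errors[n])
```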
The class 'Integrate' that appears below is an abstract class: it is not meant to be instantiated on its own; it simply acts as a template. 'Rectangle' and 'MidPoint' are its subclasses, and these inherit all the attributes and methods of the parent class. Note that the method 'initial' is not implemented in 'Integrate' because each child class requires a different form. Furthermore, the number of rectangles has a default value and is defined as a class attribute; recall that you can change this type of attribute by using statements like:
Integrate.N = 100
from __future__ import print_function
#
# inheritance in the context of integration
# this is the abstract class:
class Integrate(object):
    N = 50  # class attribute
    def __init__(self, f, xmin, xmax):
        self.f = f
        self.xmin = xmin
        self.xmax = xmax
        self.h = (xmax - xmin)/float(Integrate.N)
    def initial(self):
        raise NotImplementedError()
    def compute(self):
        x = self.initial()
        total = 0.0
        for interval in range(Integrate.N):
            total += self.h*self.f(x)  # use the stored integrand, not a global f
            x += self.h
        return total
# child class of 'Integrate':
class Rectangle(Integrate):
    def initial(self):
        return self.xmin
# another child of 'Integrate':
class MidPoint(Integrate):
    def initial(self):
        return self.xmin + 0.5*self.h
#----------------------------------------------
# Testing:
import numpy as np
def f(x): return np.sin(x)
I1 = Rectangle(f, 0.0, np.pi)
I2 = MidPoint(f, 0.0, np.pi)
print("Rectangle Rule: ", I1.compute())
print("Midpoint Rule: ", I2.compute())
# compare what is displayed with the above results
Each child class overrides the parent 'initial' method. Note that we do not call the superclass constructor in the subclasses because in this case the subclasses have no new instance variables (in general, this is the exception rather than the rule). See also the example at the end of Session 13 where a similar situation is presented (the parent class 'Diff' with the children 'Forward1' and 'Central4').
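In the class above nothing stops us from creating an Integrate object directly; the error only appears later, when compute() calls initial(). In Python 3 the standard-library module abc can enforce the "template only" role: a class that inherits from abc.ABC and marks initial with @abstractmethod cannot be instantiated until a subclass overrides that method. A sketch of that variant, which is an alternative to the approach used in these notes:

```python
from abc import ABC, abstractmethod
import math

class Integrate(ABC):
    N = 50  # class attribute: number of sub-intervals

    def __init__(self, f, xmin, xmax):
        self.f = f
        self.xmin = xmin
        self.xmax = xmax
        self.h = (xmax - xmin)/float(Integrate.N)

    @abstractmethod
    def initial(self):
        """Return the first evaluation point (supplied by subclasses)."""

    def compute(self):
        x = self.initial()
        total = 0.0
        for _ in range(Integrate.N):
            total += self.h*self.f(x)
            x += self.h
        return total

class MidPoint(Integrate):
    def initial(self):
        return self.xmin + 0.5*self.h

try:
    Integrate(math.sin, 0.0, math.pi)  # the abstract class cannot be instantiated
    created = True
except TypeError as err:
    created = False
    print("TypeError:", err)

result = MidPoint(math.sin, 0.0, math.pi).compute()
print(result)
```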
The classes 'Rectangle' and 'MidPoint' above are relatively simple because both numerical methods are based on very similar formulae. This would not be quite as straightforward if we tried to include the 'Trapezium' and the 'Simpson' rules. One possible fix would be to write the code for the latter so that it resembles the other methods:
def integrateSIMP(fname, xmin, xmax, intervals):
    """ Integrate by using Simpson's rule (intervals must be even)."""
    if intervals % 2 != 0:
        raise ValueError("Simpson's rule needs an even number of intervals")
    h = (xmax - xmin) / float(intervals)
    total = 0.0
    x = xmin  # initialise x
    for interval in range(intervals // 2):  # integer division: one parabola per pair of slices
        # Add the area below the parabola for this pair of slices:
        total += (h/3.0) * (fname(x) + 4.0*fname(x+h) + fname(x+2.0*h))  # see Session 4
        # Move to the next pair of slices:
        x += 2.0*h  # x is incremented by 2*h
    return total  # this returns the approx. of the integral
# Testing:
import numpy as np
def f(x): return np.sin(x)
print(integrateSIMP(f, 0.0, np.pi, 30))
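Grouping the terms of the loop above shows that every interior point is counted with weight 4 (odd-numbered points) or 2 (even-numbered points shared by two parabolas), which leads to the familiar composite form $\frac{h}{3}\big[f_0 + 4f_1 + 2f_2 + \dots + 4f_{n-1} + f_n\big]$. A vectorised sketch of that form follows; simpson_weights is a hypothetical name, and f is assumed to accept NumPy arrays. It should agree with integrateSIMP up to rounding, and for 30 intervals its error with respect to the exact value 2 is of the order of 1e-6.

```python
import numpy as np

def simpson_weights(f, a, b, n):
    """Composite Simpson's rule with weights 1, 4, 2, 4, ..., 2, 4, 1 (n even)."""
    if n % 2 != 0:
        raise ValueError("n must be even")
    h = (b - a)/float(n)
    x = a + h*np.arange(n + 1)  # the n+1 grid points
    w = np.ones(n + 1)
    w[1:-1:2] = 4.0             # odd-numbered points
    w[2:-1:2] = 2.0             # interior even-numbered points
    return h/3.0*np.sum(w*f(x))

approx = simpson_weights(np.sin, 0.0, np.pi, 30)
print(approx, abs(approx - 2.0))
```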
Since Euler's and Heun's methods compute the approximations iteratively, they can be implemented by using a suitable generator. This is illustrated below for the forward Euler method:
from __future__ import print_function
import numpy as np
from matplotlib import pyplot as plt
class ODEsolver(object):
    def __init__(self, f, xmin, xmax, y0, N):
        self.f = f
        self.xmin = xmin
        self.y0 = y0
        self.grid = np.linspace(xmin, xmax, N)  # N grid points
        self.h = (xmax - xmin)/float(N - 1)  # spacing between consecutive grid points
    # the stepping loop, written as a generator (the increment is supplied by 'step'):
    def generate(self):
        x, y = self.xmin, self.y0
        yield x, y  # note the yield statement
        for xnext in self.grid[1:]:
            # advance from the current point (x, y) to the next grid point
            y = y + (self.h)*self.step(self.f, x, y)
            x = xnext
            yield x, y  # note the yield statement
    def solve(self):
        self.solution = np.array(list(self.generate()))
    def plot(self):
        plt.plot(self.solution[:, 0], self.solution[:, 1])
        plt.xlabel('x')
        plt.ylabel('y')
        plt.show()
    def step(self, f, x, y):
        raise NotImplementedError()
class ForwardEuler(ODEsolver):
    def step(self, f, x, y):  # overrides the 'step' method above
        return f(x, y)
# Testing ------------------------------------------
%matplotlib notebook
def f(x, y): return x + y
euler = ForwardEuler(f, 0.0, 7.0, 1.0, 10000)
euler.solve()
euler.plot()
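The base class leaves 'step' to its subclasses, so other one-step methods can be added with very little code. As a sketch (the class name Heun and the comparison below are illustrations, not part of the original notes), Heun's method averages the slope at the current point with the slope at the Euler predictor; since step is a method, it can read the step size from self.h. The comparison uses the exact solution $y = -x - 1 + 2e^{x}$ of $y' = x + y$, $y(0) = 1$, obtained earlier in this session. Heun's method is second order and Euler's is first order, so the Heun error should be much smaller. The plotting method is omitted to keep the sketch short.

```python
import numpy as np

class ODEsolver(object):
    def __init__(self, f, xmin, xmax, y0, N):
        self.f = f
        self.xmin = xmin
        self.y0 = y0
        self.grid = np.linspace(xmin, xmax, N)  # N grid points
        self.h = self.grid[1] - self.grid[0]    # spacing between grid points

    def generate(self):
        x, y = self.xmin, self.y0
        yield x, y
        for xnext in self.grid[1:]:
            y = y + self.h*self.step(self.f, x, y)  # slope from the current point
            x = xnext
            yield x, y

    def solve(self):
        self.solution = np.array(list(self.generate()))
        return self.solution

    def step(self, f, x, y):
        raise NotImplementedError()

class ForwardEuler(ODEsolver):
    def step(self, f, x, y):
        return f(x, y)

class Heun(ODEsolver):
    def step(self, f, x, y):
        k1 = f(x, y)                       # slope at the current point
        k2 = f(x + self.h, y + self.h*k1)  # slope at the Euler predictor
        return 0.5*(k1 + k2)               # average of the two slopes

def f(x, y): return x + y
exact = -1.0 - 1.0 + 2.0*np.exp(1.0)  # y(1) from y = -x - 1 + 2e^x

errors = {}
for Solver in (ForwardEuler, Heun):
    sol = Solver(f, 0.0, 1.0, 1.0, 101).solve()  # h = 0.01
    errors[Solver.__name__] = abs(sol[-1, 1] - exact)
    print(Solver.__name__, errors[Solver.__name__])
```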